{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/mlevental/miniconda3/envs/torch-mlir/lib/python3.9/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "# standard imports\n",
    "import torch\n",
    "from torch_mlir.eager_mode import torch_mlir_tensor"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "outputs": [],
   "source": [
    "# eager mode imports\n",
    "from torch_mlir.eager_mode.torch_mlir_tensor import TorchMLIRTensor\n",
    "from amdshark.iree_eager_backend import EagerModeIREELinalgOnTensorsBackend"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "The simplest way of using Eager Mode (through IREE) requires setting a \"backend\":"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "outputs": [],
   "source": [
    "# Build the IREE eager backend for the CPU, then install it on the torch_mlir_tensor module.\n",
    "cpu_backend = EagerModeIREELinalgOnTensorsBackend(\"cpu\")\n",
    "torch_mlir_tensor.backend = cpu_backend"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "and wrapping all your `torch.Tensor`s:"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "TorchMLIRTensor(<IREE DeviceArray: shape=[10, 10], dtype=float32>, backend=EagerModeIREELinalgOnTensorsBackend)\n",
      "TorchMLIRTensor(<IREE DeviceArray: shape=[10, 10], dtype=float32>, backend=EagerModeIREELinalgOnTensorsBackend)\n"
     ]
    }
   ],
   "source": [
    "NUM_ITERS = 10\n",
    "\n",
    "# Plain torch inputs: a 10x10 tensor of ones and a 10x10 tensor of twos.\n",
    "t = torch.ones((10, 10))\n",
    "u = 2 * torch.ones((10, 10))\n",
    "\n",
    "# Wrap both tensors and show each wrapper's repr on its own line.\n",
    "tt, uu = TorchMLIRTensor(t), TorchMLIRTensor(u)\n",
    "print(tt, uu, sep=\"\\n\")"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "`TorchMLIRTensor` is a \"tensor wrapper subclass\" (more info [here](https://github.com/albanD/subclass_zoo)) that keeps the IREE `DeviceArray` in a field `elem`:"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<class 'torch_mlir.eager_mode.torch_mlir_tensor.TorchMLIRTensor'>\n",
      "[[3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
      " [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
      " [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
      " [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
      " [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
      " [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
      " [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
      " [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
      " [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
      " [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]]\n",
      "<class 'torch_mlir.eager_mode.torch_mlir_tensor.TorchMLIRTensor'>\n",
      "[[2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
      " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
      " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
      " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
      " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
      " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
      " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
      " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
      " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
      " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]]\n"
     ]
    }
   ],
   "source": [
    "# Run the add and the multiply NUM_ITERS times. Every iteration yields the\n",
    "# same values, so only the final results are printed rather than NUM_ITERS\n",
    "# identical copies of each array.\n",
    "for i in range(NUM_ITERS):\n",
    "    yy_sum = tt + uu\n",
    "    yy_prod = tt * uu\n",
    "\n",
    "# Show the result type and the values copied back to the host.\n",
    "for yy in (yy_sum, yy_prod):\n",
    "    print(type(yy))\n",
    "    print(yy.elem.to_host())"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "If you have a GPU (and CUDA installed), that works too. You can verify this by keeping `watch -n1 nvidia-smi` running in a terminal while you run the next cell:"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "TorchMLIRTensor(<IREE DeviceArray: shape=[10, 10], dtype=float32>, backend=EagerModeIREELinalgOnTensorsBackend)\n",
      "TorchMLIRTensor(<IREE DeviceArray: shape=[10, 10], dtype=float32>, backend=EagerModeIREELinalgOnTensorsBackend)\n",
      "[[3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
      " [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
      " [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
      " [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
      " [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
      " [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
      " [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
      " [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
      " [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
      " [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]]\n",
      "[[2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
      " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
      " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
      " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
      " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
      " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
      " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
      " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
      " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
      " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]]\n"
     ]
    }
   ],
   "source": [
    "# Replace the module-level backend with one constructed for \"gpu\".\n",
    "torch_mlir_tensor.backend = EagerModeIREELinalgOnTensorsBackend(\"gpu\")\n",
    "\n",
    "t = torch.ones((10, 10))\n",
    "u = 2 * torch.ones((10, 10))\n",
    "\n",
    "# Wrap the tensors again now that the backend has been switched.\n",
    "tt, uu = TorchMLIRTensor(t), TorchMLIRTensor(u)\n",
    "print(tt, uu, sep=\"\\n\")\n",
    "\n",
    "# Add, then multiply; print each result after copying it to the host.\n",
    "for yy in (tt + uu, tt * uu):\n",
    "    print(yy.elem.to_host())"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "There is a convenience class `AMDSharkEagerMode` that will handle both the installation of the backend and the wrapping of `torch.Tensor`s:"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "TorchMLIRTensor(<IREE DeviceArray: shape=[10, 10], dtype=float32>, backend=EagerModeIREELinalgOnTensorsBackend)\n",
      "TorchMLIRTensor(<IREE DeviceArray: shape=[10, 10], dtype=float32>, backend=EagerModeIREELinalgOnTensorsBackend)\n",
      "<class 'torch_mlir.eager_mode.torch_mlir_tensor.TorchMLIRTensor'>\n",
      "[[2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
      " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
      " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
      " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
      " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
      " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
      " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
      " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
      " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
      " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]]\n",
      "<class 'torch_mlir.eager_mode.torch_mlir_tensor.TorchMLIRTensor'>\n",
      "[[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
      " [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
      " [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
      " [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
      " [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
      " [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
      " [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
      " [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
      " [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
      " [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]\n"
     ]
    }
   ],
   "source": [
    "# eager mode RAII\n",
    "from amdshark.amdshark_runner import AMDSharkEagerMode\n",
    "\n",
    "amdshark_eager_mode = AMDSharkEagerMode(\"cpu\")\n",
    "\n",
    "t = torch.ones((10, 10))\n",
    "u = torch.ones((10, 10))\n",
    "\n",
    "# No explicit wrapping here: the printed reprs show these are already TorchMLIRTensors.\n",
    "print(t)\n",
    "print(u)\n",
    "\n",
    "# Run the add and the multiply NUM_ITERS times. Every iteration yields the\n",
    "# same values, so only the final results are printed rather than NUM_ITERS\n",
    "# identical copies of each array.\n",
    "for i in range(NUM_ITERS):\n",
    "    yy_sum = t + u\n",
    "    yy_prod = t * u\n",
    "\n",
    "# Show the result type and the values copied back to the host.\n",
    "for yy in (yy_sum, yy_prod):\n",
    "    print(type(yy))\n",
    "    print(yy.elem.to_host())"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "The `AMDSharkEagerMode` class is a hacky take on [RAII](https://en.wikipedia.org/wiki/Resource_acquisition_is_initialization) that defines a \"deleter\" that runs when an instance of `AMDSharkEagerMode` is garbage collected. The takeaway is that if you want to turn off `AMDSharkEagerMode`, or switch backends, you need to `del` the instance:"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "TorchMLIRTensor(<IREE DeviceArray: shape=[10, 10], dtype=float32>, backend=EagerModeIREELinalgOnTensorsBackend)\n",
      "TorchMLIRTensor(<IREE DeviceArray: shape=[10, 10], dtype=float32>, backend=EagerModeIREELinalgOnTensorsBackend)\n",
      "<class 'torch_mlir.eager_mode.torch_mlir_tensor.TorchMLIRTensor'>\n",
      "[[2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
      " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
      " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
      " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
      " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
      " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
      " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
      " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
      " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
      " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]]\n",
      "<class 'torch_mlir.eager_mode.torch_mlir_tensor.TorchMLIRTensor'>\n",
      "[[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
      " [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
      " [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
      " [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
      " [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
      " [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
      " [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
      " [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
      " [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
      " [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]\n"
     ]
    }
   ],
   "source": [
    "# Release the existing instance before creating one for a different backend.\n",
    "del amdshark_eager_mode\n",
    "amdshark_eager_mode = AMDSharkEagerMode(\"cuda\")\n",
    "\n",
    "t = torch.ones((10, 10))\n",
    "u = torch.ones((10, 10))\n",
    "\n",
    "print(t, u, sep=\"\\n\")\n",
    "\n",
    "# Add, then multiply; print each result's type and its host copy.\n",
    "for yy in (t + u, t * u):\n",
    "    print(type(yy))\n",
    "    print(yy.elem.to_host())"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}